//! Unified Worker Management Module //! //! Handles all aspects of worker lifecycle including discovery, initialization, //! runtime management, and health monitoring. use crate::config::types::{ CircuitBreakerConfig as ConfigCircuitBreakerConfig, ConnectionMode as ConfigConnectionMode, HealthCheckConfig, RouterConfig, RoutingMode, }; use crate::core::{ BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, DPAwareWorkerBuilder, HealthConfig, Worker, WorkerFactory, WorkerRegistry, WorkerType, }; use crate::grpc_client::SglangSchedulerClient; use crate::policies::PolicyRegistry; use crate::protocols::worker_spec::{ FlushCacheResult, WorkerConfigRequest, WorkerLoadInfo, WorkerLoadsResult, }; use crate::server::AppContext; use futures::future; use once_cell::sync::Lazy; use serde::{Deserialize, Serialize}; use serde_json::Value; use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; use tokio::sync::{watch, Mutex}; use tokio::task::JoinHandle; use tracing::{debug, error, info, warn}; static HTTP_CLIENT: Lazy = Lazy::new(|| { reqwest::Client::builder() .timeout(Duration::from_secs(10)) .build() .expect("Failed to create HTTP client") }); /// Server information returned from worker endpoints #[derive(Debug, Clone, Deserialize, Serialize)] pub struct ServerInfo { pub model_id: Option, pub model_path: Option, pub dp_size: Option, pub version: Option, pub max_batch_size: Option, pub max_total_tokens: Option, pub max_prefill_tokens: Option, pub max_running_requests: Option, pub max_num_reqs: Option, } /// DP (Data Parallel) information for a worker #[derive(Debug, Clone)] pub struct DpInfo { pub dp_size: usize, pub model_id: String, } /// Worker discovery results gathered from backend endpoints struct WorkerDiscovery { labels: HashMap, grpc_client: Option, } impl WorkerDiscovery { fn new() -> Self { Self { labels: HashMap::new(), grpc_client: None, } } } /// Unified worker management pub struct WorkerManager; impl WorkerManager { /// Get server info from /get_server_info endpoint pub async fn get_server_info(url: &str, api_key: Option<&str>) -> Result { let base_url = url.trim_end_matches('/'); let server_info_url = format!("{}/get_server_info", base_url); let mut req = HTTP_CLIENT.get(&server_info_url); if let Some(key) = api_key { req = req.bearer_auth(key); } let response = req .send() .await .map_err(|e| format!("Failed to connect to {}: {}", server_info_url, e))?; if !response.status().is_success() { return Err(format!( "Server returned status {} from {}", response.status(), server_info_url )); } let json = response .json::() .await .map_err(|e| format!("Failed to parse response from {}: {}", server_info_url, e))?; info!( "Successfully retrieved server info from {}", server_info_url ); Self::parse_server_info(json) } /// Get model info from /get_model_info endpoint pub async fn get_model_info(url: &str, api_key: Option<&str>) -> Result { let base_url = url.trim_end_matches('/'); let model_info_url = format!("{}/get_model_info", base_url); let mut req = HTTP_CLIENT.get(&model_info_url); if let Some(key) = api_key { req = req.bearer_auth(key); } let response = req .send() .await .map_err(|e| format!("Failed to connect to {}: {}", model_info_url, e))?; if !response.status().is_success() { return Err(format!( "Server returned status {} from {}", response.status(), model_info_url )); } let json = response .json::() .await .map_err(|e| format!("Failed to parse response from {}: {}", model_info_url, e))?; info!("Successfully retrieved model info from {}", model_info_url); Ok(json) } /// Get DP info for a worker URL pub async fn get_dp_info(url: &str, api_key: Option<&str>) -> Result { let info = Self::get_server_info(url, api_key).await?; let dp_size = info .dp_size .ok_or_else(|| format!("No dp_size in response from {}", url))?; let model_id = info .model_id .or_else(|| { info.model_path .and_then(|path| path.split('/').next_back().map(|s| s.to_string())) }) .unwrap_or_else(|| "unknown".to_string()); Ok(DpInfo { dp_size, model_id }) } /// Generate DP-aware worker URLs pub async fn get_dp_aware_urls( base_urls: &[String], api_key: Option<&str>, ) -> Result, String> { let mut dp_urls = Vec::new(); for base_url in base_urls { match Self::get_dp_info(base_url, api_key).await { Ok(dp_info) => { info!( "Discovered DP size {} for {} (model: {})", dp_info.dp_size, base_url, dp_info.model_id ); for rank in 0..dp_info.dp_size { dp_urls.push(format!("{}@{}", base_url, rank)); } } Err(e) => { return Err(format!("Failed to get DP info from {}: {}", base_url, e)); } } } Ok(dp_urls) } /// Initialize workers from configuration at startup pub async fn initialize_workers( config: &RouterConfig, registry: &Arc, policy_registry: Option<&Arc>, ) -> Result<(), String> { info!("Starting worker initialization"); // Determine connection mode from config let connection_mode = &config.connection_mode; match &config.mode { RoutingMode::Regular { worker_urls } => match connection_mode { ConfigConnectionMode::Http => { Self::initialize_regular_workers( worker_urls, config, registry, policy_registry, ) .await?; } ConfigConnectionMode::Grpc => { Self::initialize_grpc_workers(worker_urls, config, registry, policy_registry) .await?; } }, RoutingMode::PrefillDecode { prefill_urls, decode_urls, .. } => match connection_mode { ConfigConnectionMode::Http => { let prefill_entries: Vec<(&String, &Option)> = prefill_urls.iter().map(|(url, port)| (url, port)).collect(); Self::initialize_prefill_workers( &prefill_entries, config, registry, policy_registry, ) .await?; Self::initialize_decode_workers(decode_urls, config, registry, policy_registry) .await?; } ConfigConnectionMode::Grpc => { Self::initialize_grpc_pd_workers( prefill_urls, decode_urls, config, registry, policy_registry, ) .await?; } }, RoutingMode::OpenAI { .. } => { info!("OpenAI routing mode - no workers to initialize"); } } Self::wait_for_healthy_workers( registry, config.worker_startup_timeout_secs, config.health_check.check_interval_secs, ) .await?; info!("Worker initialization completed successfully"); Ok(()) } /// Initialize regular workers async fn initialize_regular_workers( urls: &[String], config: &RouterConfig, registry: &Arc, policy_registry: Option<&Arc>, ) -> Result<(), String> { info!("Creating {} regular workers", urls.len()); let connection_mode = Self::convert_connection_mode(&config.connection_mode, urls.first()); let circuit_breaker_config = Self::convert_circuit_breaker_config(&config.effective_circuit_breaker_config()); let health_config = Self::convert_health_config(&config.health_check); let mut registered_workers: HashMap>> = HashMap::new(); for url in urls { if config.dp_aware { match Self::get_dp_info(url, config.api_key.as_deref()).await { Ok(dp_info) => { info!( "Discovered DP-aware worker {} with size {}", url, dp_info.dp_size ); for rank in 0..dp_info.dp_size { let mut builder = DPAwareWorkerBuilder::new(url.clone(), rank, dp_info.dp_size) .worker_type(WorkerType::Regular) .connection_mode(connection_mode.clone()) .circuit_breaker_config(circuit_breaker_config.clone()) .health_config(health_config.clone()); if let Some(ref key) = config.api_key { builder = builder.api_key(key.clone()); } let worker = Arc::new(builder.build()) as Arc; let model_id = worker.model_id(); let worker_id = registry.register(Arc::clone(&worker)); info!( "Registered DP-aware worker {}@{} with ID {:?}", url, rank, worker_id ); registered_workers .entry(model_id.to_string()) .or_default() .push(Arc::clone(&worker)); if let Some(policy_reg) = policy_registry { policy_reg.on_worker_added(model_id, None); } } } Err(e) => { return Err(format!( "Failed to get DP info for worker {}: {}. DP-aware mode requires all workers to support DP.", url, e )); } } } else { let worker = Self::create_basic_worker( url.clone(), WorkerType::Regular, connection_mode.clone(), config.api_key.clone(), None, circuit_breaker_config.clone(), health_config.clone(), ) .await; Self::register_worker(worker, registry, &mut registered_workers, policy_registry); } } Self::initialize_cache_policies(®istered_workers, registry, policy_registry); Ok(()) } /// Initialize prefill workers for PD mode async fn initialize_prefill_workers( prefill_entries: &[(&String, &Option)], config: &RouterConfig, registry: &Arc, policy_registry: Option<&Arc>, ) -> Result<(), String> { info!("Creating {} prefill workers", prefill_entries.len()); let connection_mode = Self::convert_connection_mode( &config.connection_mode, prefill_entries.first().map(|(url, _)| *url), ); let circuit_breaker_config = Self::convert_circuit_breaker_config(&config.effective_circuit_breaker_config()); let health_config = Self::convert_health_config(&config.health_check); let mut registered_workers: HashMap>> = HashMap::new(); // TODO: Add proper DP-aware support for prefill workers in PD mode if config.dp_aware { warn!("DP-aware mode is not yet supported for prefill workers in PD mode. Creating regular prefill workers instead."); } for (url, bootstrap_port) in prefill_entries { let worker_type = WorkerType::Prefill { bootstrap_port: **bootstrap_port, }; let worker = Self::create_basic_worker( (*url).clone(), worker_type, connection_mode.clone(), config.api_key.clone(), None, circuit_breaker_config.clone(), health_config.clone(), ) .await; Self::register_worker(worker, registry, &mut registered_workers, policy_registry); } if let Some(policy_reg) = policy_registry { let all_prefill_workers: Vec> = registered_workers .values() .flat_map(|workers| workers.iter().cloned()) .collect(); policy_reg.init_pd_cache_aware_policies(&all_prefill_workers, &[]); } Ok(()) } /// Initialize decode workers for PD mode async fn initialize_decode_workers( urls: &[String], config: &RouterConfig, registry: &Arc, policy_registry: Option<&Arc>, ) -> Result<(), String> { info!("Creating {} decode workers", urls.len()); let connection_mode = Self::convert_connection_mode(&config.connection_mode, urls.first()); let circuit_breaker_config = Self::convert_circuit_breaker_config(&config.effective_circuit_breaker_config()); let health_config = Self::convert_health_config(&config.health_check); let mut registered_workers: HashMap>> = HashMap::new(); // TODO: Add proper DP-aware support for decode workers in PD mode if config.dp_aware { warn!("DP-aware mode is not yet supported for decode workers in PD mode. Creating regular decode workers instead."); } for url in urls { let worker = Self::create_basic_worker( url.clone(), WorkerType::Decode, connection_mode.clone(), config.api_key.clone(), None, circuit_breaker_config.clone(), health_config.clone(), ) .await; Self::register_worker(worker, registry, &mut registered_workers, policy_registry); } if let Some(policy_reg) = policy_registry { let all_decode_workers: Vec> = registered_workers .values() .flat_map(|workers| workers.iter().cloned()) .collect(); policy_reg.init_pd_cache_aware_policies(&[], &all_decode_workers); } Ok(()) } /// Initialize gRPC workers for regular mode async fn initialize_grpc_workers( urls: &[String], config: &RouterConfig, registry: &Arc, policy_registry: Option<&Arc>, ) -> Result<(), String> { info!("Creating {} gRPC regular workers", urls.len()); let circuit_breaker_config = Self::convert_circuit_breaker_config(&config.effective_circuit_breaker_config()); let health_config = Self::convert_health_config(&config.health_check); let connection_mode = ConnectionMode::Grpc { port: None }; let mut registered_workers: HashMap>> = HashMap::new(); for url in urls { let worker = Self::create_basic_worker( url.clone(), WorkerType::Regular, connection_mode.clone(), config.api_key.clone(), None, circuit_breaker_config.clone(), health_config.clone(), ) .await; Self::register_worker(worker, registry, &mut registered_workers, policy_registry); info!( "Registered gRPC worker at {} (will connect on first use)", url ); } Self::initialize_cache_policies(®istered_workers, registry, policy_registry); Ok(()) } /// Initialize gRPC PD (Prefill-Decode) workers async fn initialize_grpc_pd_workers( prefill_urls: &[(String, Option)], decode_urls: &[String], config: &RouterConfig, registry: &Arc, policy_registry: Option<&Arc>, ) -> Result<(), String> { info!( "Creating {} gRPC prefill workers and {} gRPC decode workers", prefill_urls.len(), decode_urls.len() ); let circuit_breaker_config = Self::convert_circuit_breaker_config(&config.effective_circuit_breaker_config()); let health_config = Self::convert_health_config(&config.health_check); let mut registered_prefill_workers: HashMap>> = HashMap::new(); let mut registered_decode_workers: HashMap>> = HashMap::new(); for (url, bootstrap_port) in prefill_urls { let worker_type = WorkerType::Prefill { bootstrap_port: *bootstrap_port, }; let connection_mode = ConnectionMode::Grpc { port: *bootstrap_port, }; let worker = Self::create_basic_worker( url.clone(), worker_type, connection_mode, config.api_key.clone(), None, circuit_breaker_config.clone(), health_config.clone(), ) .await; Self::register_worker( worker, registry, &mut registered_prefill_workers, policy_registry, ); info!( "Registered gRPC prefill worker at {} (will connect on first use)", url ); } // Create decode workers for url in decode_urls { let connection_mode = ConnectionMode::Grpc { port: None }; let worker = Self::create_basic_worker( url.clone(), WorkerType::Decode, connection_mode, config.api_key.clone(), None, circuit_breaker_config.clone(), health_config.clone(), ) .await; Self::register_worker( worker, registry, &mut registered_decode_workers, policy_registry, ); info!( "Registered gRPC decode worker at {} (will connect on first use)", url ); } if let Some(policy_reg) = policy_registry { let all_prefill_workers: Vec> = registered_prefill_workers .values() .flat_map(|workers| workers.iter().cloned()) .collect(); let all_decode_workers: Vec> = registered_decode_workers .values() .flat_map(|workers| workers.iter().cloned()) .collect(); policy_reg.init_pd_cache_aware_policies(&all_prefill_workers, &all_decode_workers); } Ok(()) } /// Add a worker from a configuration request /// /// Registers worker immediately with healthy=false, returns worker for async validation pub async fn add_worker_from_config( config: &WorkerConfigRequest, context: &AppContext, ) -> Result, String> { // Check if worker already exists if context.worker_registry.get_by_url(&config.url).is_some() { return Err(format!("Worker {} already exists", config.url)); } let mut labels = config.labels.clone(); if let Some(model_id) = &config.model_id { labels.insert("model_id".to_string(), model_id.clone()); } if let Some(priority) = config.priority { labels.insert("priority".to_string(), priority.to_string()); } if let Some(cost) = config.cost { labels.insert("cost".to_string(), cost.to_string()); } if let Some(ref tokenizer_path) = config.tokenizer_path { labels.insert("tokenizer_path".to_string(), tokenizer_path.clone()); } if let Some(ref reasoning_parser) = config.reasoning_parser { labels.insert("reasoning_parser".to_string(), reasoning_parser.clone()); } if let Some(ref tool_parser) = config.tool_parser { labels.insert("tool_parser".to_string(), tool_parser.clone()); } if let Some(ref chat_template) = config.chat_template { labels.insert("chat_template".to_string(), chat_template.clone()); } let worker_type = config .worker_type .as_ref() .map(|t| match t.as_str() { "prefill" => WorkerType::Prefill { bootstrap_port: config.bootstrap_port, }, "decode" => WorkerType::Decode, _ => WorkerType::Regular, }) .unwrap_or(WorkerType::Regular); let connection_mode = if config.url.starts_with("grpc://") { ConnectionMode::Grpc { port: None } } else { ConnectionMode::Http }; let circuit_breaker_config = Self::convert_circuit_breaker_config( &context.router_config.effective_circuit_breaker_config(), ); let health_config = Self::convert_health_config(&context.router_config.health_check); // Create and register worker (starts with healthy=false) let worker = Self::create_basic_worker( config.url.clone(), worker_type, connection_mode, config.api_key.clone(), Some(labels.clone()), circuit_breaker_config, health_config, ) .await; worker.set_healthy(false); context.worker_registry.register(worker.clone()); let policy_hint = labels.get("policy").map(|s| s.as_str()); let model_id = worker.model_id().to_string(); context .policy_registry .on_worker_added(&model_id, policy_hint); info!("Registered worker {} (initializing)", config.url); // Return worker for async validation Ok(worker) } /// Validate and activate a worker (for async validation after registration) pub async fn validate_and_activate_worker( worker: &Arc, context: &AppContext, ) -> Result { let url = worker.url(); // Perform health validation WorkerFactory::validate_health(url, context.router_config.worker_startup_timeout_secs) .await .map_err(|e| format!("Health check failed for {}: {}", url, e))?; // Mark as healthy worker.set_healthy(true); info!("Worker {} validated and activated", url); Ok(format!("Worker {} is now healthy", url)) } /// Add a worker from URL (legacy endpoint) pub async fn add_worker( url: &str, api_key: &Option, context: &AppContext, ) -> Result { Self::add_worker_internal( url, WorkerType::Regular, ConnectionMode::Http, api_key.clone(), None, None, context, ) .await } /// Remove a worker pub fn remove_worker(url: &str, context: &AppContext) -> Result { if context.router_config.dp_aware { Self::remove_dp_aware_workers(url, context) } else { Self::remove_single_worker(url, context) } } pub fn get_worker_urls(registry: &Arc) -> Vec { registry .get_all() .iter() .map(|w| w.url().to_string()) .collect() } /// Internal method to add a worker with all parameters async fn add_worker_internal( worker_url: &str, worker_type: WorkerType, connection_mode: ConnectionMode, api_key: Option, labels: Option>, policy_hint: Option<&str>, context: &AppContext, ) -> Result { WorkerFactory::validate_health( worker_url, context.router_config.worker_startup_timeout_secs, ) .await .map_err(|e| format!("Health check failed: {}", e))?; let circuit_breaker_config = Self::convert_circuit_breaker_config( &context.router_config.effective_circuit_breaker_config(), ); let health_config = Self::convert_health_config(&context.router_config.health_check); if context.router_config.dp_aware { let dp_urls = Self::get_dp_aware_urls( &[worker_url.to_string()], context.router_config.api_key.as_deref(), ) .await?; let mut workers_added = 0; let mut model_workers: HashMap>> = HashMap::new(); let dp_size_for_base = dp_urls.len(); for (rank, dp_url) in dp_urls.iter().enumerate() { if context.worker_registry.get_by_url(dp_url).is_some() { info!("Worker {} already exists, skipping", dp_url); continue; } let base_url = dp_url.split('@').next().unwrap().to_string(); let mut builder = DPAwareWorkerBuilder::new(base_url, rank, dp_size_for_base) .worker_type(worker_type.clone()) .connection_mode(connection_mode.clone()) .circuit_breaker_config(circuit_breaker_config.clone()) .health_config(health_config.clone()); if let Some(ref key) = api_key { builder = builder.api_key(key.clone()); } if let Some(ref worker_labels) = labels { builder = builder.labels(worker_labels.clone()); } let worker = Arc::new(builder.build()) as Arc; let model_id = worker.model_id().to_string(); context.worker_registry.register(worker.clone()); workers_added += 1; model_workers .entry(model_id.clone()) .or_default() .push(worker); context .policy_registry .on_worker_added(&model_id, policy_hint); } for model_id in model_workers.keys() { let all_model_workers = context.worker_registry.get_by_model_fast(model_id); if let Some(policy) = context.policy_registry.get_policy(model_id) { if policy.name() == "cache_aware" { context .policy_registry .init_cache_aware_policy(model_id, &all_model_workers); } } } if workers_added == 0 { Ok(format!("All DP workers already exist for {}", worker_url)) } else { Ok(format!( "Added {} DP-aware workers for {}", workers_added, worker_url )) } } else { if context.worker_registry.get_by_url(worker_url).is_some() { return Err(format!("Worker {} already exists", worker_url)); } let worker = Self::create_basic_worker( worker_url.to_string(), worker_type, connection_mode, api_key, labels, circuit_breaker_config, health_config, ) .await; let model_id = worker.model_id().to_string(); context.worker_registry.register(worker.clone()); context .policy_registry .on_worker_added(&model_id, policy_hint); let workers = context.worker_registry.get_by_model_fast(&model_id); if let Some(policy) = context.policy_registry.get_policy(&model_id) { if policy.name() == "cache_aware" { context .policy_registry .init_cache_aware_policy(&model_id, &workers); } } Ok(format!("Worker {} added successfully", worker_url)) } } /// Remove a single worker fn remove_single_worker(worker_url: &str, context: &AppContext) -> Result { let worker = context .worker_registry .get_by_url(worker_url) .ok_or_else(|| format!("Worker {} not found", worker_url))?; let model_id = worker.model_id().to_string(); context .policy_registry .remove_worker_from_cache_aware(&model_id, worker_url); context.worker_registry.remove_by_url(worker_url); context.policy_registry.on_worker_removed(&model_id); let remaining_workers = context.worker_registry.get_by_model_fast(&model_id); if let Some(policy) = context.policy_registry.get_policy(&model_id) { if policy.name() == "cache_aware" && !remaining_workers.is_empty() { context .policy_registry .init_cache_aware_policy(&model_id, &remaining_workers); } } Ok(format!("Worker {} removed successfully", worker_url)) } /// Remove DP-aware workers with prefix matching fn remove_dp_aware_workers(worker_url: &str, context: &AppContext) -> Result { let worker_url_prefix = format!("{}@", worker_url); let mut removed_workers = Vec::new(); let mut affected_models = std::collections::HashSet::new(); let all_workers = context.worker_registry.get_all(); for worker in all_workers.iter() { if worker.url().starts_with(&worker_url_prefix) { let model_id = worker.model_id().to_string(); affected_models.insert(model_id.clone()); context .policy_registry .remove_worker_from_cache_aware(&model_id, worker.url()); if context .worker_registry .remove_by_url(worker.url()) .is_some() { removed_workers.push(worker.url().to_string()); context.policy_registry.on_worker_removed(&model_id); } } } for model_id in affected_models { let remaining_workers = context.worker_registry.get_by_model_fast(&model_id); if let Some(policy) = context.policy_registry.get_policy(&model_id) { if policy.name() == "cache_aware" && !remaining_workers.is_empty() { context .policy_registry .init_cache_aware_policy(&model_id, &remaining_workers); } } } if removed_workers.is_empty() { Err(format!( "No workers found with prefix {}", worker_url_prefix )) } else { Ok(format!( "Removed {} DP-aware workers: {:?}", removed_workers.len(), removed_workers )) } } /// Create a basic worker async fn create_basic_worker( url: String, worker_type: WorkerType, connection_mode: ConnectionMode, api_key: Option, labels: Option>, circuit_breaker_config: CircuitBreakerConfig, health_config: HealthConfig, ) -> Arc { let discovery = Self::discover_worker_metadata(&url, &connection_mode, api_key.as_deref()).await; let mut final_labels = discovery.labels; if let Some(custom_labels) = labels { for (key, value) in custom_labels { final_labels.insert(key, value); } } let mut builder = BasicWorkerBuilder::new(url) .worker_type(worker_type) .connection_mode(connection_mode) .circuit_breaker_config(circuit_breaker_config) .health_config(health_config); if let Some(key) = api_key { builder = builder.api_key(key); } if !final_labels.is_empty() { builder = builder.labels(final_labels); } if let Some(client) = discovery.grpc_client { builder = builder.grpc_client(client); } let worker = builder.build(); Arc::new(worker) as Arc } /// Register a worker and update policies fn register_worker( worker: Arc, registry: &Arc, registered_workers: &mut HashMap>>, policy_registry: Option<&Arc>, ) { let model_id = worker.model_id(); let url = worker.url(); let worker_id = registry.register(Arc::clone(&worker)); info!("Registered worker {} with ID {:?}", url, worker_id); registered_workers .entry(model_id.to_string()) .or_default() .push(Arc::clone(&worker)); if let Some(policy_reg) = policy_registry { policy_reg.on_worker_added(model_id, None); } } /// Initialize cache-aware policies fn initialize_cache_policies( registered_workers: &HashMap>>, registry: &Arc, policy_registry: Option<&Arc>, ) { if let Some(policy_reg) = policy_registry { for model_id in registered_workers.keys() { let all_model_workers = registry.get_by_model_fast(model_id); if let Some(policy) = policy_reg.get_policy(model_id) { if policy.name() == "cache_aware" { policy_reg.init_cache_aware_policy(model_id, &all_model_workers); } } } } } /// Wait for workers to become healthy async fn wait_for_healthy_workers( registry: &Arc, timeout_secs: u64, check_interval_secs: u64, ) -> Result<(), String> { let timeout = Duration::from_secs(timeout_secs); let check_interval = Duration::from_secs(check_interval_secs); let start_time = std::time::Instant::now(); info!( "Waiting for workers to become healthy (timeout: {}s)", timeout_secs ); let workers = registry.get_all(); if workers.is_empty() { info!("No workers to wait for, continuing"); return Ok(()); } // Mark all workers as unhealthy initially info!( "Marking {} workers as unhealthy before health checks", workers.len() ); for worker in &workers { worker.set_healthy(false); } loop { // 1. Filter unhealthy workers let workers = registry.get_all(); let unhealthy_workers: Vec<_> = workers .iter() .filter(|w| !w.is_healthy()) .cloned() .collect(); // 2. If all workers are healthy, return immediately if unhealthy_workers.is_empty() { let healthy_urls: Vec<_> = workers.iter().map(|w| w.url().to_string()).collect(); info!( "All {} workers are healthy: {:?}", workers.len(), healthy_urls ); return Ok(()); } // Check timeout if start_time.elapsed() > timeout { let healthy_workers: Vec<_> = workers .iter() .filter(|w| w.is_healthy()) .map(|w| w.url().to_string()) .collect(); let unhealthy_urls: Vec<_> = unhealthy_workers .iter() .map(|w| w.url().to_string()) .collect(); error!( "Workers failed to become healthy after {}s. Unhealthy: {:?}, Healthy: {:?}", timeout_secs, unhealthy_urls, healthy_workers ); return Err(format!( "Workers failed to become healthy after {}s. Unhealthy: {:?}", timeout_secs, unhealthy_urls )); } let unhealthy_urls: Vec<_> = unhealthy_workers .iter() .map(|w| w.url().to_string()) .collect(); info!( "Waiting for {} workers to become healthy. Unhealthy: {:?}", unhealthy_workers.len(), unhealthy_urls ); // 3. Check health of all unhealthy workers in parallel let health_check_futures: Vec<_> = unhealthy_workers .iter() .map(|worker| { let w = worker.clone(); let url = worker.url().to_string(); async move { match w.check_health_async().await { Ok(_) => { w.set_healthy(true); debug!("Worker {} now healthy", url); } Err(e) => { debug!("Worker {} health check failed: {}", url, e); } } } }) .collect(); future::join_all(health_check_futures).await; // 4. Check if all workers are now healthy after health checks let still_unhealthy: Vec<_> = workers.iter().filter(|w| !w.is_healthy()).collect(); // 5. If all workers are now healthy, return immediately without sleeping if still_unhealthy.is_empty() { let healthy_urls: Vec<_> = workers.iter().map(|w| w.url().to_string()).collect(); info!( "All {} workers are healthy: {:?}", workers.len(), healthy_urls ); return Ok(()); } // 6. Otherwise, sleep before next iteration tokio::time::sleep(check_interval).await; } } /// Gather worker metadata directly from the backend before registration. async fn discover_worker_metadata( url: &str, connection_mode: &ConnectionMode, api_key: Option<&str>, ) -> WorkerDiscovery { match connection_mode { ConnectionMode::Http => Self::discover_http_metadata(url, api_key).await, ConnectionMode::Grpc { .. } => Self::discover_grpc_metadata(url).await, } } async fn discover_http_metadata(url: &str, api_key: Option<&str>) -> WorkerDiscovery { let mut discovery = WorkerDiscovery::new(); match Self::get_model_info(url, api_key).await { Ok(model_info) => { if let Some(model_path) = model_info.get("model_path").and_then(|v| v.as_str()) { if !model_path.is_empty() { discovery .labels .insert("model_path".to_string(), model_path.to_string()); } } if let Some(tokenizer_path) = model_info.get("tokenizer_path").and_then(|v| v.as_str()) { if !tokenizer_path.is_empty() { discovery .labels .insert("tokenizer_path".to_string(), tokenizer_path.to_string()); } } if let Some(served_model_name) = model_info.get("served_model_name").and_then(|v| v.as_str()) { if !served_model_name.is_empty() { discovery.labels.insert( "served_model_name".to_string(), served_model_name.to_string(), ); } } if let Some(weight_version) = model_info.get("weight_version").and_then(|v| v.as_str()) { if !weight_version.is_empty() { discovery .labels .insert("weight_version".to_string(), weight_version.to_string()); } } if let Some(model_type) = model_info.get("model_type").and_then(|v| v.as_str()) { if !model_type.is_empty() { discovery .labels .insert("model_type".to_string(), model_type.to_string()); } } if let Some(is_generation) = model_info.get("is_generation").and_then(|v| v.as_bool()) { discovery .labels .insert("is_generation".to_string(), is_generation.to_string()); } if let Some(preferred_sampling_params) = model_info .get("preferred_sampling_params") .and_then(|v| v.as_str()) { if !preferred_sampling_params.is_empty() { discovery.labels.insert( "preferred_sampling_params".to_string(), preferred_sampling_params.to_string(), ); } } if let Some(max_context_length) = model_info .get("max_context_length") .and_then(|v| v.as_i64()) { discovery.labels.insert( "max_context_length".to_string(), max_context_length.to_string(), ); } if let Some(max_req_input_len) = model_info.get("max_req_input_len").and_then(|v| v.as_i64()) { discovery.labels.insert( "max_req_input_len".to_string(), max_req_input_len.to_string(), ); } } Err(e) => { warn!( "Worker discovery: failed to fetch HTTP model info from {}: {}", url, e ); } } match Self::get_server_info(url, api_key).await { Ok(server_info) => { if let Some(model_id) = server_info.model_id { if !model_id.is_empty() { discovery.labels.insert("model_id".to_string(), model_id); } } if let Some(model_path) = server_info.model_path { if !model_path.is_empty() { discovery .labels .insert("model_path".to_string(), model_path); } } if let Some(version) = server_info.version { if !version.is_empty() { discovery .labels .insert("server_version".to_string(), version); } } if let Some(max_total_tokens) = server_info.max_total_tokens { discovery .labels .insert("max_total_tokens".to_string(), max_total_tokens.to_string()); } if let Some(max_prefill_tokens) = server_info.max_prefill_tokens { discovery.labels.insert( "max_prefill_tokens".to_string(), max_prefill_tokens.to_string(), ); } if let Some(max_running_requests) = server_info.max_running_requests { discovery.labels.insert( "max_running_requests".to_string(), max_running_requests.to_string(), ); } } Err(e) => { warn!( "Worker discovery: failed to fetch HTTP server info from {}: {}", url, e ); } } Self::finalize_model_id(&mut discovery.labels); discovery } async fn discover_grpc_metadata(url: &str) -> WorkerDiscovery { let mut discovery = WorkerDiscovery::new(); let client = match SglangSchedulerClient::connect(url).await { Ok(client) => client, Err(e) => { warn!( "Worker discovery: failed to connect to gRPC worker {}: {}", url, e ); return discovery; } }; match client.get_model_info().await { Ok(model_info) => { if !model_info.model_path.is_empty() { discovery .labels .insert("model_path".to_string(), model_info.model_path.clone()); } if !model_info.tokenizer_path.is_empty() { discovery.labels.insert( "tokenizer_path".to_string(), model_info.tokenizer_path.clone(), ); } if !model_info.served_model_name.is_empty() { discovery.labels.insert( "served_model_name".to_string(), model_info.served_model_name.clone(), ); discovery .labels .insert("model_id".to_string(), model_info.served_model_name); } if !model_info.weight_version.is_empty() { discovery.labels.insert( "weight_version".to_string(), model_info.weight_version.clone(), ); } if !model_info.model_type.is_empty() { discovery .labels .insert("model_type".to_string(), model_info.model_type.clone()); } if !model_info.preferred_sampling_params.is_empty() { discovery.labels.insert( "preferred_sampling_params".to_string(), model_info.preferred_sampling_params.clone(), ); } discovery.labels.insert( "is_generation".to_string(), model_info.is_generation.to_string(), ); if model_info.max_context_length > 0 { discovery.labels.insert( "max_context_length".to_string(), model_info.max_context_length.to_string(), ); } if model_info.max_req_input_len > 0 { discovery.labels.insert( "max_req_input_len".to_string(), model_info.max_req_input_len.to_string(), ); } if model_info.vocab_size > 0 { discovery .labels .insert("vocab_size".to_string(), model_info.vocab_size.to_string()); } } Err(e) => { warn!( "Worker discovery: failed to fetch gRPC model info from {}: {}", url, e ); } } if !discovery.labels.contains_key("model_id") { Self::finalize_model_id(&mut discovery.labels); } discovery.grpc_client = Some(client); discovery } fn finalize_model_id(labels: &mut HashMap) { let has_model_id = labels .get("model_id") .map(|v| !v.trim().is_empty()) .unwrap_or(false); if has_model_id { return; } if let Some(served_name) = labels.get("served_model_name") { if !served_name.trim().is_empty() { labels.insert("model_id".to_string(), served_name.clone()); return; } } if let Some(model_path) = labels.get("model_path") { if !model_path.trim().is_empty() { labels.insert("model_id".to_string(), model_path.clone()); } } } /// Parse server info from JSON response fn parse_server_info(json: Value) -> Result { Ok(ServerInfo { model_id: json .get("model_id") .and_then(|v| v.as_str()) .map(String::from) .or_else(|| json.get("model").and_then(|v| v.as_str()).map(String::from)), model_path: json .get("model_path") .and_then(|v| v.as_str()) .map(String::from), dp_size: json .get("dp_size") .and_then(|v| v.as_u64()) .map(|v| v as usize), version: json .get("version") .and_then(|v| v.as_str()) .map(String::from), max_batch_size: json .get("max_batch_size") .and_then(|v| v.as_u64()) .map(|v| v as usize), max_total_tokens: json .get("max_total_tokens") .and_then(|v| v.as_u64()) .map(|v| v as usize), max_prefill_tokens: json .get("max_prefill_tokens") .and_then(|v| v.as_u64()) .map(|v| v as usize), max_running_requests: json .get("max_running_requests") .and_then(|v| v.as_u64()) .map(|v| v as usize), max_num_reqs: json .get("max_num_reqs") .and_then(|v| v.as_u64()) .map(|v| v as usize), }) } /// Convert config connection mode to core connection mode fn convert_connection_mode( config_mode: &ConfigConnectionMode, _sample_url: Option<&String>, ) -> ConnectionMode { match config_mode { ConfigConnectionMode::Http => ConnectionMode::Http, ConfigConnectionMode::Grpc => ConnectionMode::Grpc { port: None }, } } /// Convert config circuit breaker to core circuit breaker fn convert_circuit_breaker_config(config: &ConfigCircuitBreakerConfig) -> CircuitBreakerConfig { CircuitBreakerConfig { failure_threshold: config.failure_threshold, success_threshold: config.success_threshold, timeout_duration: Duration::from_secs(config.timeout_duration_secs), window_duration: Duration::from_secs(config.window_duration_secs), } } /// Convert config health check to core health config fn convert_health_config(config: &HealthCheckConfig) -> HealthConfig { HealthConfig { timeout_secs: config.timeout_secs, check_interval_secs: config.check_interval_secs, endpoint: config.endpoint.clone(), failure_threshold: config.failure_threshold, success_threshold: config.success_threshold, } } /// Flush cache on all workers /// /// Sends a POST request to /flush_cache endpoint on all HTTP workers. /// Returns detailed results showing which workers succeeded and which failed. pub async fn flush_cache_all( worker_registry: &WorkerRegistry, client: &reqwest::Client, ) -> Result { warn!("Flushing cache for ALL workers - this may impact performance temporarily"); let workers = worker_registry.get_all(); let http_workers: Vec<_> = workers .iter() .filter(|w| matches!(w.connection_mode(), ConnectionMode::Http)) .collect(); if http_workers.is_empty() { return Ok(FlushCacheResult { successful: vec![], failed: vec![], total_workers: workers.len(), http_workers: 0, message: "No HTTP workers available for cache flush".to_string(), }); } info!( "Flushing cache on {} HTTP workers (out of {} total workers)", http_workers.len(), workers.len() ); let mut tasks = Vec::new(); for worker in &http_workers { let url = worker.url().to_string(); let flush_url = format!("{}/flush_cache", url); let mut request = client.post(&flush_url); if let Some(api_key) = worker.api_key() { request = request.header("Authorization", format!("Bearer {}", api_key)); } let worker_url = url.clone(); tasks.push(async move { let result = request.send().await; (worker_url, result) }); } let results = future::join_all(tasks).await; let mut successful = Vec::new(); let mut failed = Vec::new(); for (url, result) in results { match result { Ok(response) if response.status().is_success() => { debug!("Successfully flushed cache on worker: {}", url); successful.push(url); } Ok(response) => { let error = format!("HTTP {}", response.status()); warn!("Failed to flush cache on worker {}: {}", url, error); failed.push((url, error)); } Err(e) => { let error = e.to_string(); error!("Failed to connect to worker {}: {}", url, error); failed.push((url, error)); } } } let message = if failed.is_empty() { format!( "Successfully flushed cache on all {} HTTP workers", successful.len() ) } else { format!( "Cache flush completed: {} succeeded, {} failed (out of {} HTTP workers)", successful.len(), failed.len(), http_workers.len() ) }; info!("{}", message); Ok(FlushCacheResult { successful, failed, total_workers: workers.len(), http_workers: http_workers.len(), message, }) } pub async fn get_worker_load( url: &str, api_key: Option<&str>, client: &reqwest::Client, ) -> Option { let load_url = format!("{}/get_load", url); let mut request = client.get(&load_url); if let Some(key) = api_key { request = request.bearer_auth(key); } match request.send().await { Ok(response) if response.status().is_success() => { match response.json::().await { Ok(json) => { // The /get_load endpoint returns an array of load info objects (one per DP rank) // Each object has: {dp_rank, num_reqs, num_waiting_reqs, num_tokens} if let Some(array) = json.as_array() { let total_tokens: i64 = array .iter() .filter_map(|entry| { entry.get("num_tokens").and_then(|v| v.as_i64()) }) .sum(); debug!("Worker {} load (total tokens): {}", url, total_tokens); Some(total_tokens as isize) } else { warn!( "Invalid load response from {}: expected array, got {:?}", url, json ); None } } Err(e) => { warn!("Failed to parse load response from {}: {}", url, e); None } } } Ok(response) => { warn!( "Failed to get load from {}: HTTP {}", url, response.status() ); None } Err(e) => { warn!("Failed to connect to {} for load check: {}", url, e); None } } } pub async fn get_all_worker_loads( worker_registry: &WorkerRegistry, client: &reqwest::Client, ) -> WorkerLoadsResult { let workers = worker_registry.get_all(); let total_workers = workers.len(); // Prepare tasks for parallel execution let mut tasks = Vec::new(); for worker in &workers { let url = worker.url().to_string(); let api_key = worker.api_key().clone(); let worker_type = match worker.worker_type() { WorkerType::Regular => None, WorkerType::Prefill { .. } => Some("prefill".to_string()), WorkerType::Decode => Some("decode".to_string()), }; let is_http = matches!(worker.connection_mode(), ConnectionMode::Http); let client = client.clone(); tasks.push(async move { let load = if is_http { Self::get_worker_load(&url, api_key.as_deref(), &client) .await .unwrap_or(-1) } else { -1 }; WorkerLoadInfo { worker: url, worker_type, load, } }); } let loads = future::join_all(tasks).await; let successful = loads.iter().filter(|l| l.load >= 0).count(); let failed = loads.iter().filter(|l| l.load < 0).count(); WorkerLoadsResult { loads, total_workers, successful, failed, } } } /// Load monitoring service that periodically fetches worker loads pub struct LoadMonitor { worker_registry: Arc, policy_registry: Arc, client: reqwest::Client, interval: Duration, tx: watch::Sender>, rx: watch::Receiver>, monitor_handle: Arc>>>, } impl LoadMonitor { /// Create a new load monitor pub fn new( worker_registry: Arc, policy_registry: Arc, client: reqwest::Client, interval_secs: u64, ) -> Self { let (tx, rx) = watch::channel(HashMap::new()); Self { worker_registry, policy_registry, client, interval: Duration::from_secs(interval_secs), tx, rx, monitor_handle: Arc::new(Mutex::new(None)), } } /// Start monitoring worker loads pub async fn start(&self) { let mut handle_guard = self.monitor_handle.lock().await; if handle_guard.is_some() { debug!("Load monitoring already running"); return; } info!( "Starting load monitoring with interval: {:?}", self.interval ); let worker_registry = Arc::clone(&self.worker_registry); let policy_registry = Arc::clone(&self.policy_registry); let client = self.client.clone(); let interval = self.interval; let tx = self.tx.clone(); let handle = tokio::spawn(async move { Self::monitor_loop(worker_registry, policy_registry, client, interval, tx).await; }); *handle_guard = Some(handle); } /// Stop monitoring worker loads pub async fn stop(&self) { let mut handle_guard = self.monitor_handle.lock().await; if let Some(handle) = handle_guard.take() { info!("Stopping load monitoring"); handle.abort(); let _ = handle.await; // Wait for task to finish } } /// Get a receiver for load updates pub fn subscribe(&self) -> watch::Receiver> { self.rx.clone() } /// The main monitoring loop async fn monitor_loop( worker_registry: Arc, policy_registry: Arc, client: reqwest::Client, interval: Duration, tx: watch::Sender>, ) { let mut interval_timer = tokio::time::interval(interval); loop { interval_timer.tick().await; let power_of_two_policies = policy_registry.get_all_power_of_two_policies(); if power_of_two_policies.is_empty() { debug!("No PowerOfTwo policies found, skipping load fetch"); continue; } let result = WorkerManager::get_all_worker_loads(&worker_registry, &client).await; let mut loads = HashMap::new(); for load_info in result.loads { loads.insert(load_info.worker, load_info.load); } if !loads.is_empty() { debug!( "Fetched loads from {} workers, updating {} PowerOfTwo policies", loads.len(), power_of_two_policies.len() ); for policy in &power_of_two_policies { policy.update_loads(&loads); } let _ = tx.send(loads); } else { warn!("No loads fetched from workers"); } } } /// Check if monitoring is currently active pub async fn is_running(&self) -> bool { let handle_guard = self.monitor_handle.lock().await; handle_guard.is_some() } } impl Drop for LoadMonitor { fn drop(&mut self) { if let Ok(mut handle_guard) = self.monitor_handle.try_lock() { if let Some(handle) = handle_guard.take() { handle.abort(); } } } } #[cfg(test)] mod tests { use super::*; use std::collections::HashMap; #[test] fn test_parse_server_info() { let json = serde_json::json!({ "model_id": "llama-3", "model_path": "/models/llama-3", "dp_size": 4, "version": "0.1.0" }); let info = WorkerManager::parse_server_info(json).unwrap(); assert_eq!(info.model_id, Some("llama-3".to_string())); assert_eq!(info.dp_size, Some(4)); } #[test] fn test_parse_server_info_with_fallback() { let json = serde_json::json!({ "model": "gpt-4", "dp_size": 2 }); let info = WorkerManager::parse_server_info(json).unwrap(); assert_eq!(info.model_id, Some("gpt-4".to_string())); assert_eq!(info.dp_size, Some(2)); } #[test] fn test_parse_server_info_minimal() { let json = serde_json::json!({}); let info = WorkerManager::parse_server_info(json).unwrap(); assert_eq!(info.model_id, None); assert_eq!(info.dp_size, None); } #[test] fn test_finalize_model_id_prefers_existing() { let mut labels = HashMap::new(); labels.insert("model_id".to_string(), "manual-id".to_string()); labels.insert("served_model_name".to_string(), "auto-id".to_string()); WorkerManager::finalize_model_id(&mut labels); assert_eq!(labels.get("model_id").unwrap(), "manual-id"); } #[test] fn test_finalize_model_id_prefers_served_name() { let mut labels = HashMap::new(); labels.insert("served_model_name".to_string(), "served-name".to_string()); WorkerManager::finalize_model_id(&mut labels); assert_eq!(labels.get("model_id").unwrap(), "served-name"); } #[test] fn test_finalize_model_id_falls_back_to_path() { let mut labels = HashMap::new(); labels.insert("model_path".to_string(), "/models/alpha".to_string()); WorkerManager::finalize_model_id(&mut labels); assert_eq!(labels.get("model_id").unwrap(), "/models/alpha"); } }