From fde9b96392ce3b80348b014feb23aeafc4015562 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Tue, 7 Oct 2025 19:53:10 -0400 Subject: [PATCH] [router] cleanup worker health check to return early (#11310) --- sgl-router/src/core/worker.rs | 142 ++++++++++++++------------ sgl-router/src/core/worker_manager.rs | 124 +++++++++++----------- 2 files changed, 133 insertions(+), 133 deletions(-) diff --git a/sgl-router/src/core/worker.rs b/sgl-router/src/core/worker.rs index 3f5f2bb76..b091afa26 100644 --- a/sgl-router/src/core/worker.rs +++ b/sgl-router/src/core/worker.rs @@ -9,11 +9,14 @@ use serde_json; use std::fmt; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::{Arc, LazyLock}; +use std::time::Duration; +use std::time::Instant; use tokio::sync::{Mutex, RwLock}; +use tokio::time; static WORKER_CLIENT: LazyLock = LazyLock::new(|| { reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(30)) + .timeout(Duration::from_secs(30)) .build() .expect("Failed to create worker HTTP client") }); @@ -227,6 +230,8 @@ pub trait Worker: Send + Sync + fmt::Debug { async fn reset_grpc_client(&self) -> WorkerResult<()> { Ok(()) } + async fn grpc_health_check(&self) -> WorkerResult; + async fn http_health_check(&self) -> WorkerResult; } /// Connection mode for worker communication @@ -407,66 +412,9 @@ impl Worker for BasicWorker { } async fn check_health_async(&self) -> WorkerResult<()> { - use std::time::Duration; - let health_result = match &self.metadata.connection_mode { - ConnectionMode::Http => { - let url = self.normalised_url()?; - let health_url = format!("{}{}", url, self.metadata.health_config.endpoint); - let timeout = Duration::from_secs(self.metadata.health_config.timeout_secs); - - let mut request = WORKER_CLIENT.get(&health_url).timeout(timeout); - - if let Some(ref api_key) = self.metadata.api_key { - request = request.header("Authorization", format!("Bearer {}", api_key)); - } - - match request.send().await { - Ok(response) => response.status().is_success(), - Err(_) => false, - } - } - ConnectionMode::Grpc { .. } => { - // Use the new get_grpc_client() method for lazy initialization - match self.get_grpc_client().await { - Ok(Some(grpc_client)) => { - let mut client = grpc_client.lock().await; - match client.health_check().await { - Ok(response) => { - tracing::debug!( - "gRPC health check succeeded for {}: healthy={}", - self.metadata.url, - response.healthy - ); - response.healthy - } - Err(e) => { - tracing::warn!( - "gRPC health check RPC failed for {}: {:?}", - self.metadata.url, - e - ); - false - } - } - } - Ok(None) => { - tracing::error!( - "Worker {} is not a gRPC worker but has gRPC connection mode", - self.metadata.url - ); - false - } - Err(e) => { - tracing::error!( - "Failed to get gRPC client for worker {}: {:?}", - self.metadata.url, - e - ); - false - } - } - } + ConnectionMode::Http => self.http_health_check().await?, + ConnectionMode::Grpc { .. } => self.grpc_health_check().await?, }; if health_result { @@ -594,6 +542,61 @@ impl Worker for BasicWorker { } } } + + async fn grpc_health_check(&self) -> WorkerResult { + let timeout = Duration::from_secs(self.metadata.health_config.timeout_secs); + let maybe = self.get_grpc_client().await?; + let Some(grpc_client) = maybe else { + tracing::error!( + "Worker {} is not a gRPC worker but connection mode is gRPC", + self.metadata.url + ); + return Ok(false); + }; + + let mut client = grpc_client.lock().await; + match time::timeout(timeout, client.health_check()).await { + Ok(Ok(resp)) => { + tracing::debug!( + "gRPC health OK for {}: healthy={}", + self.metadata.url, + resp.healthy + ); + Ok(resp.healthy) + } + Ok(Err(err)) => { + tracing::warn!("gRPC health RPC error for {}: {err:?}", self.metadata.url); + Ok(false) + } + Err(_) => { + tracing::warn!("gRPC health timed out for {}", self.metadata.url); + Ok(false) + } + } + } + + async fn http_health_check(&self) -> WorkerResult { + let timeout = Duration::from_secs(self.metadata.health_config.timeout_secs); + + let url = self.normalised_url()?; + let health_url = format!("{}{}", url, self.metadata.health_config.endpoint); + + let mut req = WORKER_CLIENT.get(health_url).timeout(timeout); + if let Some(api_key) = &self.metadata.api_key { + req = req.bearer_auth(api_key); + } + + match req.send().await { + Ok(resp) => Ok(resp.status().is_success()), + Err(err) => { + tracing::warn!( + "HTTP health check failed for {}: {err:?}", + self.metadata.url + ); + Ok(false) + } + } + } } /// A DP-aware worker that handles data-parallel routing @@ -730,6 +733,14 @@ impl Worker for DPAwareWorker { async fn reset_grpc_client(&self) -> WorkerResult<()> { self.base_worker.reset_grpc_client().await } + + async fn grpc_health_check(&self) -> WorkerResult { + self.base_worker.grpc_health_check().await + } + + async fn http_health_check(&self) -> WorkerResult { + self.base_worker.http_health_check().await + } } /// Worker factory for creating workers of different types @@ -755,10 +766,8 @@ impl WorkerFactory { /// Static health validation before creating a worker /// This replaces wait_for_worker_health in handlers pub async fn validate_health(url: &str, timeout_secs: u64) -> WorkerResult<()> { - use std::time::Instant; - let start_time = Instant::now(); - let timeout = std::time::Duration::from_secs(timeout_secs); + let timeout = Duration::from_secs(timeout_secs); loop { if start_time.elapsed() > timeout { @@ -775,7 +784,7 @@ impl WorkerFactory { // API key authentication is handled in the worker instance's check_health_async method match WORKER_CLIENT .get(format!("{}/health", url)) - .timeout(std::time::Duration::from_secs(5)) + .timeout(Duration::from_secs(5)) .send() .await { @@ -795,7 +804,7 @@ impl WorkerFactory { } } - tokio::time::sleep(std::time::Duration::from_secs(1)).await; + tokio::time::sleep(Duration::from_secs(1)).await; } } } @@ -891,8 +900,7 @@ pub fn start_health_checker( let shutdown_clone = shutdown.clone(); let handle = tokio::spawn(async move { - let mut interval = - tokio::time::interval(tokio::time::Duration::from_secs(check_interval_secs)); + let mut interval = tokio::time::interval(Duration::from_secs(check_interval_secs)); // Counter for periodic load reset (every 10 health check cycles) let mut check_count = 0u64; diff --git a/sgl-router/src/core/worker_manager.rs b/sgl-router/src/core/worker_manager.rs index 1ee3f2e8e..0d4185440 100644 --- a/sgl-router/src/core/worker_manager.rs +++ b/sgl-router/src/core/worker_manager.rs @@ -953,113 +953,105 @@ impl WorkerManager { return Ok(()); } + // Mark all workers as unhealthy initially info!( - "Marking {} workers as unhealthy before initial health checks", + "Marking {} workers as unhealthy before health checks", workers.len() ); for worker in &workers { worker.set_healthy(false); } - info!( - "Performing initial health checks for {} workers", - workers.len() - ); - let health_check_futures: Vec<_> = workers - .iter() - .map(|worker| { - let w = worker.clone(); - let url = worker.url().to_string(); - async move { - match w.check_health_async().await { - Ok(_) => { - w.set_healthy(true); - debug!( - "Worker {} passed initial health check and marked healthy", - url - ); - Ok(url) - } - Err(e) => { - warn!("Worker {} failed initial health check: {}", url, e); - Err(url) - } - } - } - }) - .collect(); - - let health_results = future::join_all(health_check_futures).await; - let failed_checks: Vec<_> = health_results.into_iter().filter_map(|r| r.err()).collect(); - - if !failed_checks.is_empty() { - info!( - "Initial health check: {} workers failed: {:?}", - failed_checks.len(), - failed_checks - ); - } - loop { + // 1. Filter unhealthy workers let workers = registry.get_all(); - let healthy_workers: Vec<_> = workers - .iter() - .filter(|w| w.is_healthy()) - .map(|w| w.url().to_string()) - .collect(); let unhealthy_workers: Vec<_> = workers .iter() .filter(|w| !w.is_healthy()) - .map(|w| w.url().to_string()) + .cloned() .collect(); + // 2. If all workers are healthy, return immediately if unhealthy_workers.is_empty() { + let healthy_urls: Vec<_> = workers.iter().map(|w| w.url().to_string()).collect(); info!( "All {} workers are healthy: {:?}", workers.len(), - healthy_workers + healthy_urls ); return Ok(()); } + // Check timeout if start_time.elapsed() > timeout { + let healthy_workers: Vec<_> = workers + .iter() + .filter(|w| w.is_healthy()) + .map(|w| w.url().to_string()) + .collect(); + let unhealthy_urls: Vec<_> = unhealthy_workers + .iter() + .map(|w| w.url().to_string()) + .collect(); + error!( "Workers failed to become healthy after {}s. Unhealthy: {:?}, Healthy: {:?}", - timeout_secs, unhealthy_workers, healthy_workers + timeout_secs, unhealthy_urls, healthy_workers ); return Err(format!( "Workers failed to become healthy after {}s. Unhealthy: {:?}", - timeout_secs, unhealthy_workers + timeout_secs, unhealthy_urls )); } + let unhealthy_urls: Vec<_> = unhealthy_workers + .iter() + .map(|w| w.url().to_string()) + .collect(); + info!( "Waiting for {} workers to become healthy. Unhealthy: {:?}", unhealthy_workers.len(), - unhealthy_workers + unhealthy_urls ); - let unhealthy_workers_to_check = workers + // 3. Check health of all unhealthy workers in parallel + let health_check_futures: Vec<_> = unhealthy_workers .iter() - .filter(|w| !w.is_healthy()) - .cloned() - .collect::>(); - - for worker in unhealthy_workers_to_check { - let url = worker.url().to_string(); - match worker.check_health_async().await { - Ok(_) => { - if !worker.is_healthy() { - worker.set_healthy(true); - debug!("Worker {} now healthy after health check", url); + .map(|worker| { + let w = worker.clone(); + let url = worker.url().to_string(); + async move { + match w.check_health_async().await { + Ok(_) => { + w.set_healthy(true); + debug!("Worker {} now healthy", url); + } + Err(e) => { + debug!("Worker {} health check failed: {}", url, e); + } } } - Err(e) => { - debug!("Worker {} health check failed: {}", url, e); - } - } + }) + .collect(); + + future::join_all(health_check_futures).await; + + // 4. Check if all workers are now healthy after health checks + let still_unhealthy: Vec<_> = workers.iter().filter(|w| !w.is_healthy()).collect(); + + // 5. If all workers are now healthy, return immediately without sleeping + if still_unhealthy.is_empty() { + let healthy_urls: Vec<_> = workers.iter().map(|w| w.url().to_string()).collect(); + info!( + "All {} workers are healthy: {:?}", + workers.len(), + healthy_urls + ); + return Ok(()); } + // 6. Otherwise, sleep before next iteration tokio::time::sleep(check_interval).await; } }